-
Notifications
You must be signed in to change notification settings - Fork 11
WIP - removeBinary #414
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: feature/onnx-to-tosa
Are you sure you want to change the base?
WIP - removeBinary #414
Conversation
|
I added @xiaohanAMD as reviewer too since he wrote the initial pattern and some improvements |
tvivies-amd
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A couple of small comments, I like the usage of if constexpr and the state structure.
| } | ||
| state.kValue = kValueOpt.value(); | ||
|
|
||
| // Debug - To be removed |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ideally you would not remove the debug ouptut that you used during the implementation, I would prefer if you could log them instead. @ehsan-toosi do you know if there is any logging capabilities in this repo ?
| if (failed(match_qdq(state, dqOp1, dqOp2))) | ||
| return failure(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: add bracket here too, in order to keep the coding style consistent
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok
| // Debug - To be removed | ||
| llvm::outs() << "B. SUCCESS\n"; | ||
| llvm::outs() << "kValue = " << state.kValue << "\n"; | ||
| printOnnxNodeName(binaryOp, "[RemoveBinary] matched"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you can put this in the match_qdq method and return directly the output of the call to match_qdq
| for (Value res : op->getResults()) { | ||
| for (Operation *user : res.getUsers()) { | ||
| if (auto q = dyn_cast<ONNXQuantizeLinearOp>(user)) { | ||
| quantOutputOp = q; | ||
| break; | ||
| } | ||
| } | ||
| if (quantOutputOp) | ||
| break; | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If the binary op is a qdq operation, I would expect that the output of the binOp only has one user that is the QuantizeLinear operation. This is what the python implementation is expecting and it is fine if we have the same check here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok.
| .add<FoldBinaryThroughQDQ<ONNXDivOp>, FoldBinaryThroughQDQ<ONNXSubOp>, | ||
| FoldBinaryThroughQDQ<ONNXMulOp>, FoldBinaryThroughQDQ<ONNXAddOp>>( | ||
| &getContext()); | ||
| if (failed(applyPatternsAndFoldGreedily(function, std::move(patterns)))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wonder if we need to use the greedy approach here, since we are not creating any new ops that would needs to be visited by the pass and same for the modified ops they do not need to be matched again. See documentation of the walker configuration: https://mlir.llvm.org/docs/PatternRewriter/#walk-pattern-rewrite-driver
What do you think ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Tried replacing applyPatternsAndFoldGreedily with walkAndApplyPatterns. But the code is crashing for some test cases. Not sure why. Therefore, restored greedy driver again.
|
this pattern is very complicated when there is fork in the match chain. Please add test case to fix this scenario, here Q1 has a fork, we expect not to fold into DQ1, but in Q2: |
@xiaohanAMD |
I didn't see that, sure, let's implement the basic in this PR. More complicate case we can do later. |
REMOVE_BINARY summary:
This RemoveBinary class is a pattern-matching and rewrite utility that finds and removes binary ops (Add, Sub, Mul, Div) in quantized ONNX graphs when one input is effectively a constant. It folds the constant into the quantization parameters instead of keeping an explicit binary op in the graph.
This class rewrites patterns like: